# -*- coding: utf-8 -*-
import time
import csv
import os
import argparse
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

from backend import *
from datasets import *
from utils import *

# Seed PyTorch's global random number generator as soon as this module is
# loaded, so repeated runs start from the same random state.
torch.manual_seed(0)

def ParseArgs(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            which is what the script entry point relies on. Passing an
            explicit list makes the parser usable from other code.

    Returns:
        argparse.Namespace with the attributes ``lmbda`` (float),
        ``max_epoch`` (int), ``backend`` (str) and ``dataset_name`` (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lmbda', default=1e-3, type=float, help='weighting parameters')
    parser.add_argument('--max_epoch', default=300, type=int)
    parser.add_argument('--backend', default='vgg16', type=str)  # vgg16 | resnet18
    parser.add_argument('--dataset_name', default='cifar10', type=str)  # cifar10 | mnist
    return parser.parse_args(argv)

class SProx(Optimizer):
    """Stochastic proximal-gradient optimizer.

    Every parameter group carries two hyper-parameters:

    * ``alpha`` -- the step size (learning rate);
    * ``lmbda`` -- the weight of the regularization term.

    In :meth:`step`, parameters for which ``is_conv_weights(p.shape)`` is
    true are updated with the group-sparsity proximal mapping
    (:meth:`prox_mapping_group`); all other parameters receive a plain
    gradient step. ``is_conv_weights`` is not defined in this class; it is
    expected to be supplied by the module's project imports.
    """

    def __init__(self, params, alpha=required, lmbda = required):
        """Create the optimizer.

        Args:
            params: Iterable of parameters (or parameter-group dicts).
            alpha: Step size; must be non-negative.
            lmbda: Regularization weight; must be non-negative.

        Raises:
            ValueError: If ``alpha`` or ``lmbda`` is negative.
        """
        if alpha is not required and alpha < 0.0:
            raise ValueError("Invalid learning rate: {}".format(alpha))

        if lmbda is not required and lmbda < 0.0:
            raise ValueError("Invalid lambda: {}".format(lmbda))

        defaults = dict(alpha=alpha, lmbda=lmbda)
        super(SProx, self).__init__(params, defaults)

    def prox_mapping_l1(self, x, grad_f, lmbda, alpha):
        """Return the proximal-gradient displacement for Omega(x) = ||x||_1.

        The trial point is the element-wise soft-thresholding of
        ``x - alpha * grad_f`` with threshold ``alpha * lmbda``.

        Args:
            x: Current iterate.
            grad_f: Gradient of the smooth loss at ``x`` (same shape as ``x``).
            lmbda: Regularization weight.
            alpha: Step size.

        Returns:
            Tensor ``d`` such that ``x + d`` is the next iterate.
        """
        trial_x = torch.zeros_like(x)
        pos_shrink = x - alpha * grad_f - alpha * lmbda  # used where the result stays positive
        neg_shrink = x - alpha * grad_f + alpha * lmbda  # used where the result stays negative
        pos_shrink_idx = (pos_shrink > 0)
        neg_shrink_idx = (neg_shrink < 0)
        # Entries selected by neither mask remain exactly zero.
        trial_x[pos_shrink_idx] = pos_shrink[pos_shrink_idx]
        trial_x[neg_shrink_idx] = neg_shrink[neg_shrink_idx]
        return trial_x - x

    def prox_mapping_group(self, x, grad_f, lmbda, alpha):
        """Return the proximal-gradient displacement for the group penalty.

        Omega(x) = sum_{g in G} ||[x]_g||_2, where each group ``g`` is one
        slice of ``x`` along its first dimension (one kernel for a 4-D
        convolution weight). Each group of the gradient-step point is scaled
        by ``max(0, 1 - alpha * lmbda / (||group||_2 + 1e-6))``, so groups
        whose norm does not exceed roughly ``alpha * lmbda`` become zero.

        Args:
            x: Current iterate with at least one dimension.
            grad_f: Gradient of the smooth loss at ``x`` (same shape as ``x``).
            lmbda: Regularization weight.
            alpha: Step size.

        Returns:
            Tensor ``delta`` such that ``x + delta`` is the next iterate.
        """
        trial_x = x - alpha * grad_f
        num_groups = x.shape[0]
        # reshape (rather than view) also copes with non-contiguous layouts.
        group_norms = torch.norm(trial_x.reshape(num_groups, -1), p=2, dim=1)
        # 1e-6 guards against division by zero for groups that are already zero.
        coeffs = 1.0 - (alpha * lmbda) / (group_norms + 1e-6)
        coeffs = torch.clamp(coeffs, min=0.0)
        # One coefficient per group, broadcast over all remaining dimensions.
        coeffs = coeffs.view(-1, *([1] * (x.dim() - 1)))
        return coeffs * trial_x - x

    def __setstate__(self, state):
        """Restore optimizer state; defers entirely to the base class."""
        super(SProx, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            alpha = group['alpha']
            lmbda = group['lmbda']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad_f = p.grad.data

                if is_conv_weights(p.shape):  # weights: group-sparse proximal step
                    delta = self.prox_mapping_group(p.data, grad_f, lmbda, alpha)
                    p.data.add_(delta)
                else:  # everything else (e.g. bias): plain gradient step
                    p.data.add_(grad_f, alpha=-alpha)

        return loss

    def adjust_learning_rate(self, epoch):
        """Divide the step size of every group by 10 every 75 epochs.

        Args:
            epoch: Number of completed epochs; the decay is applied when it
                is a positive multiple of 75.
        """
        if epoch % 75 == 0 and epoch > 0:
            for group in self.param_groups:
                # The step size is stored under 'alpha'; there is no 'lr' key.
                group['alpha'] /= float(10)


if __name__ == "__main__":
    # Training driver: builds the data loaders and model, trains with SProx
    # for `max_epoch` epochs and logs one CSV row of metrics per epoch.
    # Dataset, Model, compute_func_values, compute_sparsity and
    # check_accuracy are supplied by the project's wildcard imports.

    args = ParseArgs()
    lmbda = args.lmbda
    max_epoch = args.max_epoch
    backend = args.backend
    dataset_name = args.dataset_name
    alpha = 1e-1  # initial step size handed to the optimizer

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainloader, testloader = Dataset(dataset_name)
    model = Model(backend, device)

    # Only parameters whose name contains "weight" enter the regularizer and
    # sparsity statistics.
    weights = [w for name, w in model.named_parameters() if "weight" in name]

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = SProx(model.parameters(), alpha=alpha, lmbda=lmbda)

    os.makedirs('results', exist_ok=True)
    csvname = 'results/proxsg_%s_%s_%E.csv'%(backend, dataset_name, lmbda)
    print('The csv file is %s'%csvname)

    fieldnames = ['epoch', 'F_value', 'f_value', 'omega_value', 'sparsity', 'sparsity_tol', 'sparsity_group', 'validation_acc', 'train_time', 'remarks']
    remarks = '%s;%s;%E'%(backend, dataset_name, lmbda)

    # The context manager guarantees the log file is closed (and its buffer
    # written out) even if training raises part-way through.
    with open(csvname, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=",")
        writer.writeheader()

        alg_start_time = time.time()

        epochs_done = 0
        # Epochs are numbered from 1 in the log and in the learning-rate schedule.
        for epoch in range(1, max_epoch + 1):
            epoch_start_time = time.time()

            # NOTE(review): if check_accuracy switches the model to eval mode,
            # it is never switched back here -- confirm against utils.
            for X, y in trainloader:
                X = X.to(device)
                y = y.to(device)
                y_pred = model.forward(X)

                f = criterion(y_pred, y)
                optimizer.zero_grad()
                f.backward()
                optimizer.step()

            optimizer.adjust_learning_rate(epoch)

            # Only the optimization pass is timed; evaluation below is excluded.
            train_time = time.time() - epoch_start_time
            F, f, omega = compute_func_values(trainloader, model, weights, criterion, lmbda)
            sparsity, sparsity_tol, sparsity_group, _ = compute_sparsity(weights)
            accuracy = check_accuracy(model, testloader)

            writer.writerow({
                'epoch': epoch,
                'F_value': F,
                'f_value': f,
                'omega_value': omega,
                'sparsity': sparsity,
                'sparsity_tol': sparsity_tol,
                'sparsity_group': sparsity_group,
                'validation_acc': accuracy,
                'train_time': train_time,
                'remarks': remarks,
            })

            # Flush every epoch so partial results survive an interrupted run.
            csvfile.flush()
            print("epoch {}: {:.2f}seconds ...".format(epoch, train_time))
            epochs_done = epoch

        alg_time = time.time() - alg_start_time
        # Final summary row: average wall-clock time per epoch (evaluation
        # included). Skipped when no epoch ran to avoid dividing by zero.
        if epochs_done > 0:
            writer.writerow({'train_time': alg_time / epochs_done})

